Parameter estimation from simulated data

Pre-requisites

Julia packages used:

using DifferentialEquations, Plots, StatsPlots
using Flux, DiffEqFlux, Optim, DiffEqSensitivity
import Statistics
using Turing, Distributions, MCMCChains
using Random
using Logging

Simulating data

We use the same equations as before. This time we specify some times to sample the solution to obtain data to use in parameter estimation.

function droop!(du, u, p, t)
  R, Q, X = u                      # resource, cell quota, biomass
  Km, Vmax, Qmin, muMax, d, R0 = p
  rho = Vmax * R / (Km + R)        # Michaelis-Menten uptake rate
  mu = muMax * (1 - Qmin/Q)        # Droop growth rate
  du[1] = dRdt = d*(R0 - R) - rho*X
  du[2] = dQdt = rho - mu*Q
  du[3] = dXdt = (mu - d)*X
end
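As a quick sanity check on the right-hand side, the rates can be computed by hand at the initial state R = Q = X = 1 with the parameter values used below (Km = 0.1, Vmax = 2.0, Qmin = 1.0, muMax = 0.8, d = R0 = 0):

```julia
# Hand evaluation of the Droop rates at R = Q = X = 1
rho = 2.0 * 1.0 / (0.1 + 1.0)   # uptake rate, ≈ 1.818
mu  = 0.8 * (1 - 1.0 / 1.0)     # growth rate, = 0 because the quota starts at Qmin
du = [0.0 * (0.0 - 1.0) - rho * 1.0,  # dR/dt ≈ -1.818 (uptake only, no inflow)
      rho - mu * 1.0,                 # dQ/dt ≈ 1.818 (the quota fills up)
      (mu - 0.0) * 1.0]               # dX/dt = 0 (no growth yet)
```

This matches the intuition that the culture first fills its internal quota before biomass can grow.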

The initial conditions, parameters, and time-span for the solution must be specified.

u0 = [1.0, 1.0, 1.0]
p = [0.1, 2.0, 1.0, 0.8, 0.0, 0.0]
tspan = (0.0, 10.0)
tsteps = [0.1, 3.2, 4.5, 7.0, 9.1]

Now we create and solve the ODE initial value problem. We create two solutions: one solved over the whole interval with smooth interpolation, and one saved only at the few discrete sample times.

prob = ODEProblem(droop!, u0, tspan, p)
sol1 = solve(prob, Tsit5())
sol2 = solve(prob, Rosenbrock23(), saveat = tsteps)

We will plot the solution and the discrete samples.

Plots.plot(sol1)
Plots.scatter!(sol2)

We can convert the solution to a matrix and add some noise to the output.

data1 = Array(sol2)
data1 = data1 .* (1 .+ 0.05*randn(size(data1)));
Plots.plot(sol1, label = ["R" "Q" "X"])
Plots.scatter!(sol2.t, data1', label = ["R" "Q" "X"])
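Since the noise is drawn with `randn`, fixing the RNG seed beforehand makes the synthetic data reproducible across runs. A sketch with a stand-in matrix (the seed value is arbitrary):

```julia
using Random

Random.seed!(42)                     # arbitrary seed; any fixed value works
demo = ones(3, 5)                    # stand-in for Array(sol2)
noisy1 = demo .* (1 .+ 0.05 .* randn(size(demo)))

Random.seed!(42)                     # re-seeding reproduces the same draws
noisy2 = demo .* (1 .+ 0.05 .* randn(size(demo)))
# noisy1 == noisy2
```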

Optimization to find parameters

First we write a function to describe the difference between a trial solution and the data points. For the optimizer we use next, this should be a function of just the parameters to be adjusted. Here the loss is the sum of squared differences between the solution and the data; two standardized variants, which scale each variable by its standard deviation, are left commented out in the function below.

function loss(p)
    tspan = (0.0, 10.0)
    u0 = p[1:3]
    param = [p[4:7]; 0.0; 0.0]  # force d = R0 = 0.0
    prob = ODEProblem(droop!, u0, tspan, param)
    sol2 = solve(prob, Rosenbrock23(), saveat = tsteps)
    data2 = Array(sol2)
    # loss = sum(((1.0 ./ Statistics.std(data1, dims=2)') * (data2 .- data1)) .^ 2)
    # loss = sum(((1.0 ./ (1 .+ Statistics.std(data1, dims=2)')) * (data2 .- data1)) .^ 2)
    loss = sum((data2 .- data1) .^ 2)
    return loss, sol2
end
loss (generic function with 1 method)

(Note: I originally standardized the three variables as in the commented loss definition above, but this resulted in a worse solution, because R stays so close to 0 that its standard deviation is about 10⁻⁷. A possible fix is to add a small constant to the standard deviations to prevent this distortion. My first attempt, also commented out, led to instabilities.)
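The small-constant fix suggested above can be sketched as a standardized loss with a floor on each standard deviation, so a near-constant variable such as R cannot blow up the objective. The function name and the floor value are illustrative choices, not part of the original code:

```julia
import Statistics

# Standardized sum of squares with a floor `eps` on each per-variable std.
function loss_std(data1, data2; eps = 1e-3)
    s = Statistics.std(data1, dims = 2) .+ eps  # 3×1 scale, one row per variable
    sum(((data2 .- data1) ./ s) .^ 2)           # broadcasting divides each row by its scale
end
```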

The optimizer allows us to provide a callback function to show the loss score, or make a plot, at each iteration. Here's an example callback function.

callback = function (p, l, pred)
  display(l)
  # plt = plot(pred, ylim = (0, 6))
  # display(plt)
  # Tell sciml_train not to halt the optimization; if the callback returns
  # true, the optimization stops.
  return false
end
#1 (generic function with 1 method)

We can test the loss function.

loss([1.0; 1.0; 1.0; p])
(0.06793220570062616, t: [0.1, 3.2, 4.5, 7.0, 9.1]
u: [[0.8194448617123582, 1.1731174127897068, 1.0063561287792864], [-4.324744135240712e-13, 1.092662631142999, 1.8307612862597094], [-5.1339331824071424e-15, 1.0323880304289377, 1.9376392208132505], [4.375444133471588e-18, 1.00420411122562, 1.992008464727631], [-2.660209756869351e-19, 1.0006732783836856, 1.9990287483735543]])

We are now ready to use the optimizer. ADAM is a gradient-based search method.

result_ode = DiffEqFlux.sciml_train(loss, [1.0, 1.0, 1.0, 0.1, 2.0, 1.0, 1.0, 0, 0],
                                    ADAM(0.001),  # the learning rate must be small to avoid instability
                                    cb = callback,
                                    maxiters = 200)
u: 9-element Vector{Float64}:
 0.992108008950313
 1.0032707036649016
 1.0794707905497742
 0.07131417326456733
 2.028185142334497
 1.0748924965236635
 1.0145068941313244
 0.0
 0.0

Now we use the parameters from this optimization to solve the differential equation and plot the solution.

prob = ODEProblem(droop!, result_ode.u[1:3], tspan, result_ode.u[4:end])
sol3 = solve(prob, Rosenbrock23())

We will plot the solution and the discrete samples.

Plots.plot(sol1, label = ["Original R" "Original Q" "Original X"])
Plots.plot!(sol3, lw = 2, label = ["New R" "New Q" "New X"])
Plots.scatter!(sol2.t, data1', label = ["Data R" "Data Q" "Data X"])

Compare parameters.

xaxis = ["R(0)" "Q(0)" "X(0)" "Km" "Vmax" "Qmin" "muMax" "d" "R0"]
estimate = result_ode.u'
original = [u0; p]'
error = (estimate .- original) ./ original  # d and R0 are 0, so only the first 7 entries are meaningful
p1 = Plots.scatter(xaxis, estimate, legend = false, title = "Parameters")
p1 = Plots.scatter!(xaxis, original, legend = false)
p2 = Plots.scatter(xaxis[1:7], error[1:7], legend = false, title = "Relative error")
plot(p1, p2, layout = (2,1))

Bayesian MCMC parameter estimation

Set a random seed to make the result reproducible, and select forward-mode automatic differentiation as the Turing backend.

Random.seed!(14)
Turing.setadbackend(:forwarddiff)

Define a model for the parameters, including priors, the solution of the ODE, and the comparison between the data and the solution. Note that only Q and X enter the likelihood; the R data are used only to construct the initial numeric problem.

@model function fitDroop1(t, R, Q, X)
    σ1 ~ InverseGamma(2, 3)  # positive support; parameters α, β; mean β/(α-1), here 3
    σ2 ~ InverseGamma(2, 3)
    R0 ~ truncated(Normal(1, 1), 0, 5)
    Q0 ~ truncated(Normal(1, 1), 0, 5)
    X0 ~ truncated(Normal(1, 1), 0, 5)
    Km ~ truncated(Normal(4.0, 2), 0, 5)
    Vmax ~ truncated(Normal(1.2, 2), 0, 5)
    Qmin ~ truncated(Normal(1.0, 2), 0, 5)
    muMax ~ truncated(Normal(1.0, 2), 0, 5)

    p = [ Km, Vmax, Qmin, muMax, 0.0, 0.0]

    # must define the problem with numeric values first, then update with distributions
    prob1 = ODEProblem(droop!, [R[1], Q[1], X[1]], (0.0, 10.0), [0.1, 1.0, 1.0, 1.0, 0.0, 0.0])
    prob = remake(prob1, u0=[R0, Q0, X0], p=p)
    predicted = solve(prob, Rosenbrock23(), saveat=t)

    for j = 1:length(t)
        Q[j] ~ Normal(predicted[j][2], σ1)
        X[j] ~ Normal(predicted[j][3], σ2)
    end
end
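The per-time-point loop is equivalent to a multivariate normal likelihood with diagonal covariance, which Turing also accepts (e.g. `Q ~ MvNormal(predicted[2, :], σ1^2 * I)` with `I` from LinearAlgebra). The equivalence of the two forms can be checked with Distributions alone, on made-up numbers:

```julia
using Distributions, LinearAlgebra

pred = [1.0, 2.0, 3.0]   # made-up "predicted" values
obs  = [1.1, 1.9, 3.2]   # made-up "observations"
σ = 0.5

# Scalar loop, as in the model above
ll_loop = sum(logpdf(Normal(pred[j], σ), obs[j]) for j in 1:3)
# Single multivariate normal with diagonal covariance
ll_mv = logpdf(MvNormal(pred, σ^2 * I), obs)
# ll_loop ≈ ll_mv
```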

Create the model and simulate the chains.

t = sol2.t
R = [v[1] for v in sol2.u]
Q = [v[2] for v in sol2.u]
X = [v[3] for v in sol2.u]

model = fitDroop1(t, R, Q, X)
chain2 = sample(model, NUTS(.65), MCMCThreads(), 200, 4, progress=false)  # multi-threaded
# chain2 = mapreduce(c -> sample(model, NUTS(.75), 200), chainscat, 1:4) # single thread
┌ Info: Found initial step size
└   ϵ = 0.4
┌ Info: Found initial step size
└   ϵ = 0.05
┌ Info: Found initial step size
└   ϵ = 0.4
┌ Info: Found initial step size
└   ϵ = 0.0125
┌ Warning: The current proposal will be rejected due to numerical error(s).
│   isfinite.((θ, r, ℓπ, ℓκ)) = (true, false, false, false)
└ @ AdvancedHMC ~/.julia/packages/AdvancedHMC/MIxdK/src/hamiltonian.jl:47

Extract data so we can plot trajectories for a selection of parameter values drawn from the posterior distribution. Parameters come out of the chains in alphabetical order, so I resequence them into the order used by the ODE function: initial conditions first, then parameter values.

chain_array0 = Array(chain2);
chain_array = chain_array0[:, [4, 2, 6, 1, 5, 3, 7]]  # R(0) Q(0) X(0) Km Vmax Qmin muMax
N = size(chain_array)[1]
N2 = N ÷ 2  # sample only from the second half of the draws
pR = Plots.scatter(t, R);
for k in 1:300
    pars = [ chain_array[rand(N2:N), :];  0.0; 0.0 ] # append d, Rin
    resol = solve(remake(prob, u0 = pars[1:3], p = pars[4:end]), Rosenbrock23())
    pR = plot!(resol, vars=(0,1), alpha=0.3, color = "#BBBBBB", legend = false)
end
pR = plot!(sol1, vars=(0,1), alpha=1, color = "#BB0000", legend = false, ylims=(0, Inf),
        yguide = "R")
pQ = Plots.scatter(t, Q);
for k in 1:300
    pars = [ chain_array[rand(N2:N), :];  0.0; 0.0 ] # append d, Rin
    resol = solve(remake(prob, u0 = pars[1:3], p = pars[4:end]), Rosenbrock23())
    pQ = plot!(resol, vars=(0,2), alpha=0.3, color = "#BBBBBB", legend = false)
end
pQ = plot!(sol1, vars=(0,2), alpha=1, color = "#BB0000", legend = false, ylims=(0, Inf),
        yguide = "Q")
pX = Plots.scatter(t, log.(X));
for k in 1:300
    pars = [ chain_array[rand(N2:N), :];  0.0; 0.0 ] # append d, Rin
    resol = solve(remake(prob, u0 = pars[1:3], p = pars[4:end]), Rosenbrock23())
    pX = plot!(resol, vars=((t,x) -> (t, log.(x)), 0, 3), alpha=0.3, color = "#BBBBBB", legend = false)
end
pX = plot!(sol1, vars=((t,x) -> (t, log.(x)), 0,3), alpha=1, color = "#BB0000", legend = false,
        yguide = "log X")
plot(pR, pQ, pX, layout = (3,1))
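Point estimates can be read off as posterior means and compared with the true values. A sketch, using a stand-in matrix of draws in place of `chain_array` (columns ordered R(0), Q(0), X(0), Km, Vmax, Qmin, muMax):

```julia
import Statistics
using Random

Random.seed!(1)
truth = [1.0, 1.0, 1.0, 0.1, 2.0, 1.0, 0.8]   # true initial conditions and parameters
# Stand-in for the real posterior draws; replace with `chain_array` from above.
draws = truth' .+ 0.05 .* randn(400, 7)

post_mean = vec(Statistics.mean(draws, dims = 1))  # posterior-mean estimate per parameter
rel_err = (post_mean .- truth) ./ truth            # relative error against the true values
```

With the real chains, the same two lines applied to `chain_array` give posterior-mean parameter estimates directly comparable to the optimizer's result.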